import sqlparse
from sqlparse.sql import Identifier, Function, IdentifierList, Parenthesis
from sqlparse.tokens import Number, String, Keyword, DML, Punctuation

from core.tools import get_data_route, get_data_by_route


def pre_check(func):
    """Decorator letting ``func`` accept SQL in several forms.

    The single argument passed to the decorated function may be:

    * SQL text - it is parsed with ``sqlparse.parse``;
    * the tuple returned by ``sqlparse.parse`` - only its first statement is used;
    * anything else (e.g. an already-parsed statement or token) - passed through.

    ``functools.wraps`` keeps the wrapped function's name and docstring, which
    the bare inner function used to hide.
    """
    import functools

    @functools.wraps(func)
    def decorate(sql):
        if isinstance(sql, str):
            sql = sqlparse.parse(sql)
        if isinstance(sql, tuple):
            # sqlparse.parse returns one entry per statement; use the first.
            sql = sql[0]
        return func(sql)

    return decorate


@pre_check
def get_type_table(parsed):
    """
    Extract the parts of a SQL statement we are interested in.

    Tables named in FROM/JOIN clauses come first, followed by the target
    table of an UPDATE or INSERT ... INTO statement.

    :return: (statement type, [table names])
    """
    statement_type = parsed.get_type()

    if statement_type == 'UPDATE':
        target_tokens = list(extract_update_part(parsed))
    elif statement_type == 'INSERT':
        target_tokens = list(extract_into_part(parsed))
    else:
        target_tokens = []

    table_tokens = list(extract_from_part(parsed))
    table_tokens.extend(target_tokens)

    return statement_type, list(extract_table_identifiers(table_tokens))


def sql_val_eq(sql_val, param_val):
    """Return True if ``param_val`` equals the SQL literal ``sql_val``.

    The literal is compared as text, both verbatim and with one pair of
    surrounding quotes removed, so that the SQL literal ``'abc'`` matches the
    parameter ``abc``.

    :param sql_val: SQL literal token (or any object; it is converted with ``str``).
    :param param_val: parameter value as a string.
    """
    text = str(sql_val)
    if param_val == text:
        return True
    # Strip only when the literal really is wrapped in matching quotes. Slicing
    # [1:-1] unconditionally made e.g. 123 match "2" and 5 match "".
    if len(text) >= 2 and text[0] == text[-1] and text[0] in ("'", '"'):
        return param_val == text[1:-1]
    return False


def get_sql_value_by_route(sql, route):
    """Return the token reached by following ``route`` from ``sql``.

    ``route`` is a sequence of child indices: each step moves to
    ``node.tokens[index]``. An empty route returns ``sql`` itself.
    """
    # TODO: an empty route may need an extra marker to distinguish it.
    node = sql
    for position in route:
        node = node.tokens[position]
    return node


def build_format_sql(gp, diff_sign):
    """Render ``gp`` as text, highlighting the tokens addressed by ``diff_sign``.

    ``diff_sign`` is a list of index paths (lists of child indices starting at
    ``gp``). A token whose path ends at it is wrapped in
    ``<font class='text-error'>...</font>``; every other token is rendered as-is.

    Note: ``diff_sign`` is consumed - each path used while walking ``gp.tokens``
    is removed from the caller's list; paths that match no child stay in it.
    """
    if not diff_sign:
        return str(gp)
    if not diff_sign[0]:
        # An empty path addresses gp itself.
        return "<font class='text-error'>{}</font>".format(str(gp))

    pieces = []
    for idx, child in enumerate(gp.tokens):
        matched = []
        # Walk backwards so deleting an entry does not shift unvisited ones.
        for pos in reversed(range(len(diff_sign))):
            path = diff_sign[pos]
            if path[0] == idx:
                matched.append(path[1:])
                del diff_sign[pos]
        pieces.append(build_format_sql(child, matched) if matched else str(child))
    return "".join(pieces)


def sub_token(token_type, goal):
    """Return True if ``goal.ttype`` is ``token_type`` or one of its subtypes.

    Token types are treated as tuple-like paths of names, as sqlparse's are
    (e.g. ``('Literal', 'Number')`` is a prefix of
    ``('Literal', 'Number', 'Integer')``). The earlier set-difference test
    ignored ordering and raised ``TypeError`` for tokens whose ``ttype`` is
    ``None`` (grouped tokens); such tokens now simply do not match.
    """
    ttype = goal.ttype
    if ttype is None:
        return False
    prefix = tuple(token_type)
    return tuple(ttype[:len(prefix)]) == prefix


def make_sql_key(gp, diff_sign):
    """Build a normalised text key for the parsed SQL tree ``gp``.

    Number and string literal tokens are replaced by the string form of their
    token type; every other leaf keeps its text. For each replaced literal, a
    path of child indices leading from ``gp`` down to that literal is appended
    to ``diff_sign`` (mutated in place).
    """
    if not gp.is_group:
        if sub_token(Number, gp) or sub_token(String, gp):
            # Leaf literal: record an empty path; enclosing groups prefix it.
            diff_sign.append([])
            return str(gp.ttype)
        return str(gp)

    parts = []
    for idx, child in enumerate(gp.tokens):
        before = len(diff_sign)
        parts.append(make_sql_key(child, diff_sign))
        # Paths added by this child get this child's index prepended.
        for path in diff_sign[before:]:
            path.insert(0, idx)
    return "".join(parts)


def _sign_near(node, index):
    """Search the children of ``node`` around child ``index`` for a sign.

    When ``index`` is 0 the siblings after it are scanned forwards; otherwise
    the siblings before it are scanned backwards. The first keyword found
    yields ``'table-KEYWORD'`` strings (joined with ``'|'``) for the tables of
    the whole statement; the first ``Identifier`` found is returned as-is.
    Returns ``None`` when neither is found.
    """
    if index == 0:
        candidates = node.tokens[1:]
    else:
        candidates = node.tokens[index - 1::-1]
    for token in candidates:
        if token.is_keyword:
            # Climb to the statement root to collect its tables.
            root = token.parent
            while root.parent:
                root = root.parent
            _, tables = get_type_table(root)
            return '|'.join([t + '-' + str(token).upper() for t in tables])
        if isinstance(token, Identifier):
            return token
    return None


def last_identifier(sql, route, depth=None):
    """Find the identifier/keyword sign nearest to the token at ``route``.

    ``route`` is a path of child indices from ``sql``. Levels along the path are
    searched from the deepest (``depth``, default ``len(route) - 1``) outwards,
    each level once, and the first non-``None`` result of ``_sign_near`` is
    returned. With ``depth == 0`` only ``sql``'s own children are searched.

    The previous recursive form re-searched the same levels repeatedly
    (2**(depth-1) searches); the levels and their order of first visit are
    unchanged, so the result is the same.

    NOTE(review): for ``depth > 0`` the top level (``sql`` itself) is never
    searched, exactly as before - confirm whether that is intended.

    :return: an ``Identifier`` token, a ``'table-KEYWORD'`` string, or ``None``.
    """
    if depth is None:
        depth = len(route) - 1
    if depth < 0:
        return None
    if depth == 0:
        return _sign_near(sql, route[0])

    # nodes[level] is the node whose child route[level] lies on the path.
    nodes = [sql]
    for position in route[:depth]:
        nodes.append(nodes[-1].tokens[position])

    for level in range(depth, 0, -1):
        found = _sign_near(nodes[level], route[level])
        if found is not None:
            return found
    return None


def get_about_sign(sql, route):
    """Return the sign related to the value at ``route`` in statement ``sql``.

    Only SELECT and UPDATE statements are handled (via ``last_identifier``);
    every other statement type yields ``None``.
    """
    statement_type = sql.get_type()
    if statement_type in ("SELECT", "UPDATE"):
        return last_identifier(sql, route)
    # TODO(INSERT): locate the matching parenthesis and join it with the table
    # name - not implemented yet, so INSERT falls through to None.
    return None


def str_sql_list(sql):
    """Return the SQL text of a token, or of a list of tokens concatenated."""
    if not isinstance(sql, list):
        return sql.value
    return "".join(token.value for token in sql)


def identifier_alias(item):
    """Yield ``(source, alias)`` string pairs for an identifier.

    ``item`` is split around the first ``AS`` keyword, or failing that around
    whitespace (``t1 alias``): ``source`` is the text up to the last
    non-whitespace token before the separator, ``alias`` the text from the
    first non-whitespace token after it to the end. Without a separator the
    identifier is paired with itself.

    When the source is a ``Parenthesis`` (a derived table / sub-query), the
    aliases found inside it are yielded with the outer alias prefixed
    (``"outer.inner"``), followed by ``(make_sql_key(subquery, []), alias)``.

    NOTE(review): ``source`` and ``goal`` are set independently, so they can
    come from different separator positions; and if a separator has nothing
    but whitespace after it, ``goal`` stays ``None`` and ``str_sql_list(goal)``
    will raise.
    """
    source, goal = None, None
    # Try an "AS" separator first, then plain whitespace.
    check_condition = [lambda x: x.match(Keyword, 'as'), lambda x: x.is_whitespace]
    for condition in check_condition:
        if source and goal:
            break
        for i, it in enumerate(item):
            if source and goal:
                break
            if condition(it):
                # alias: from the first non-whitespace token after the separator to the end.
                for j in range(i + 1, len(item.tokens)):
                    if not item[j].is_whitespace:
                        goal = item[j:]
                        break
                # source: from the start up to the last non-whitespace token before the separator.
                for j in range(i - 1, -1, -1):
                    if not item[j].is_whitespace:
                        source = item[:j + 1]
                        break
    if source is None:
        # No separator: the identifier is its own alias.
        source, goal = [item], [item]
    if isinstance(source[0], Parenthesis):  # this is a derived (temporary) table
        source, goal = source[0], goal[0]  # simply assume each side has exactly one element
        res = list(get_all_alias(source))
        for s, g in res:
            yield s, "{}.{}".format(goal, g)
        yield make_sql_key(source, []), str_sql_list(goal)
    else:
        yield str_sql_list(source), str_sql_list(goal)


@pre_check
def get_obj_fields(sql):
    """Collect ``object -> [field, ...]`` pairs referenced through punctuation.

    For every ``Identifier`` anywhere in ``sql``, each punctuation token (e.g.
    the ``.`` in ``obj.field``) contributes the text on its right to
    ``result[text on its left]``. Parenthesised parts of an identifier are
    searched recursively.

    :param sql: SQL text, the tuple from ``sqlparse.parse``, or a parsed statement.
    :return: dict mapping object name to a list of unique field names, in
        first-seen order.
    """
    def obj_with_field(t_item):
        """Yield (left, right) text pairs around the punctuation of an identifier."""
        tokens = t_item.tokens
        for i, it in enumerate(tokens):
            if it.ttype == Punctuation:
                # Edge punctuation has no neighbour on one side: index -1 would
                # wrap to the last token and len(tokens) would raise IndexError.
                if 0 < i < len(tokens) - 1:
                    yield str_sql_list(tokens[i - 1]), str_sql_list(tokens[i + 1])
            if isinstance(it, Parenthesis):
                yield from inner(it)

    def inner(inner_sql):
        """Recursively yield (object, field) pairs from ``inner_sql`` and its groups."""
        if isinstance(inner_sql, Identifier):
            yield from obj_with_field(inner_sql)
        if not inner_sql.is_group:
            # `return` ends the generator; `raise StopIteration` here is turned
            # into RuntimeError since Python 3.7 (PEP 479).
            return
        for item in inner_sql:
            # Identifier and IdentifierList are groups too, so a single recursive
            # call covers them. They used to be visited twice per level
            # (exponential in nesting depth, same results after de-duplication).
            if item.is_group:
                yield from inner(item)

    grouped = {}
    for obj, field in inner(sql):
        grouped.setdefault(obj, []).append(field)
    # De-duplicate while keeping first-seen order (set() gave an arbitrary order).
    return {obj: list(dict.fromkeys(fields)) for obj, fields in grouped.items()}


@pre_check
def get_all_alias(sql):
    """Yield ``(source, alias)`` pairs for the identifiers of ``sql``.

    Handles ``sql`` itself being an ``Identifier``, its direct ``Identifier``
    children, and the members of its direct ``IdentifierList`` children.

    Note: tables inside a nested SELECT are reported with the outer alias as a
    prefix; callers must strip that leading alias themselves.
    """
    if isinstance(sql, Identifier):
        yield from identifier_alias(sql)
        # `return` ends the generator; `raise StopIteration` inside a generator
        # becomes RuntimeError since Python 3.7 (PEP 479).
        return
    if not sql.is_group:
        # Plain tokens (whitespace, punctuation, ...) carry no alias.
        return
    for item in sql:
        if isinstance(item, Identifier):
            yield from identifier_alias(item)
        elif isinstance(item, IdentifierList):
            for j in item:
                yield from get_all_alias(j)


def check_param_same_route(data_sql, route_checkpoint, sql_variable):
    """Match data values against literal values in the captured SQL statements.

    For every route into each ``data`` item, the first time a (route, key)
    pair is compared, every literal position of that SQL (``sql_variable[key]``)
    whose value equals the data value is recorded as a candidate in
    ``route_checkpoint[route]``. On later comparisons of the same pair,
    candidates (for that key) whose literal no longer equals the data value are
    discarded. A warning is produced for each remaining candidate whose
    (route, key) pair has been compared only once so far.

    :param data_sql: iterable of ``(data, sql_key)`` pairs, where ``sql_key``
        is an iterable of ``(sql, key)`` pairs.
    :param route_checkpoint: dict ``{route tuple: [(key, sql_route, sign), ...]}``,
        updated in place (``sign`` comes from ``get_about_sign``).
    :param sql_variable: mapping ``key -> iterable of routes`` into the SQL
        tree at which literal values sit (presumably the paths collected by
        ``make_sql_key`` - confirm with caller).
    :return: list of ``WarningInfo`` objects.

    NOTE(review): ``checked`` is only initialised for routes not yet present in
    ``route_checkpoint``; if the caller passes a ``route_checkpoint`` that
    already contains a route, ``checked[route]`` raises KeyError.
    """
    checked = {}  # {route: set(key)} - SQL keys already compared for a route
    warnings = []
    check_for_twice = {}  # {(route, key): number of comparisons so far}
    for data, sql_key in data_sql:
        for route in get_data_route(data):
            route = tuple(route)  # hashable, so it can be used as a dict key
            val = str(get_data_by_route(data, route))
            if route not in route_checkpoint:  # first time this route is seen: initialise
                route_checkpoint[route] = []
                checked[route] = set()
            for (sql, key) in sql_key:
                if key not in checked[route]:
                    # First comparison for this (route, key): record every literal
                    # position whose value equals the data value.
                    checked[route].add(key)
                    check_for_twice[(route, key)] = 1
                    for sql_route in sql_variable[key]:
                        sql_val = get_sql_value_by_route(sql, sql_route)
                        if sql_val_eq(sql_val, val):  # compare ignoring surrounding quotes
                            route_checkpoint[route].append((key, sql_route, get_about_sign(sql, sql_route)))
                else:
                    # Repeat comparison: drop this key's candidates whose literal no
                    # longer equals the data value. Iterate backwards so pop(i) does
                    # not shift entries not yet visited.
                    check_for_twice[(route, key)] += 1
                    for i in range(len(route_checkpoint[route]) - 1, -1, -1):
                        t_key, t_route, _ = route_checkpoint[route][i]
                        if t_key != key:
                            continue
                        sql_val = get_sql_value_by_route(sql, t_route)
                        if not sql_val_eq(sql_val, val):
                            route_checkpoint[route].pop(i)
            # Warn about candidates that have been compared only once so far.
            # NOTE(review): this runs after every route of every data item, not once
            # at the end, so a candidate can be warned about before a later data item
            # compares it again, and repeatedly while its count stays at 1.
            for i in range(len(route_checkpoint[route])-1,-1,-1):
                key, sql_route, goal = route_checkpoint[route][i]
                if check_for_twice[(route, key)] == 1:
                    # route_checkpoint[route].pop(i)
                    # Local imports, kept inside the function as in the original.
                    from core.documents import WarningInfo
                    from core.constant import WarningLv
                    warning = WarningInfo()
                    warning.content = WarningLv.SQL_VAL_CHECK_ONE_TIME_WARNING.format(route=route)
                    warning.lv = WarningLv.WARNING
                    warnings.append(warning)
    return warnings


def extract_update_part(parsed):
    """Yield the token(s) naming the target table of an UPDATE statement.

    After the UPDATE keyword, the first ``Identifier`` is yielded. A keyword
    directly after UPDATE is yielded as well (table names that collide with
    keywords). The next keyword after the table ends the scan until another
    UPDATE is seen.
    """
    if not parsed.is_group:
        return
    after_update = False
    table_found = False
    for token in parsed.tokens:
        if after_update:
            if token.ttype is Keyword:
                if table_found:
                    after_update = False
                else:
                    table_found = True
                    yield token
            elif isinstance(token, Identifier):
                table_found = True
                yield token
        if token.is_keyword and token.value.upper() == 'UPDATE':
            after_update = True
            table_found = False


def is_subselect(parsed):
    """Return True if a SELECT DML token appears anywhere below ``parsed``.

    Plain (non-group) tokens never contain a sub-select.
    """
    if not parsed.is_group:
        return False
    return any(
        is_subselect(child) or (child.ttype is DML and child.value.upper() == 'SELECT')
        for child in parsed.tokens
    )


def extract_into_part(parsed):
    """Yield the token(s) naming the target table of an ``INSERT ... INTO``.

    After the INTO keyword, the first ``Identifier`` is yielded. If a
    ``Function`` token follows instead (e.g. ``INTO t (a, b)``), its
    ``Identifier`` children are yielded. A keyword directly after INTO is
    yielded too (table names that collide with keywords). The next keyword
    after the table ends the scan until another INTO is seen.
    """
    if not parsed.is_group:
        return
    after_into = False
    table_found = False
    for token in parsed.tokens:
        if after_into:
            if token.ttype is Keyword:
                if table_found:
                    after_into = False
                else:
                    table_found = True
                    yield token
            elif isinstance(token, Identifier):
                table_found = True
                yield token
            elif isinstance(token, Function):
                for child in token.tokens:
                    if isinstance(child, Identifier):
                        table_found = True
                        yield child
        if token.ttype is Keyword and token.value.upper() == 'INTO':
            after_into = True
            table_found = False


def extract_from_part(parsed):
    """Yield the tokens naming tables after FROM / JOIN keywords in ``parsed``.

    Sub-selects are searched recursively. After a FROM or ``...JOIN`` keyword,
    ``Identifier`` / ``IdentifierList`` tokens are yielded. A keyword directly
    after FROM/JOIN is yielded too, because some table names collide with
    keywords. The next keyword after a table ends the FROM section.

    Fix: the FROM/JOIN test read ``A and B or 'JOIN' in value``, so *any* token
    whose text contained "JOIN" (e.g. an identifier ``joined_t`` or a WHERE
    clause mentioning 'JOIN') restarted FROM mode. A following keyword such as
    ORDER BY or LIMIT was then reported as a table. Only keyword tokens may
    start a FROM section now.
    """
    from_seen = False
    get_from = True
    if not parsed.is_group:
        return
    for item in parsed.tokens:
        if is_subselect(item):
            if from_seen:
                get_from = True
            yield from extract_from_part(item)
            continue
        if from_seen:
            if item.ttype is Keyword:
                if not get_from:
                    get_from = True
                    yield item  # some table names equal keywords, so yield the first keyword
                else:
                    from_seen = False
            elif item.is_whitespace:
                continue
            elif isinstance(item, (Identifier, IdentifierList)):
                get_from = True
                yield item
        if item.ttype is Keyword:
            keyword = item.value.upper()
            if keyword == 'FROM' or 'JOIN' in keyword:
                from_seen = True
                get_from = False


def get_outer_identifiers_name(token_stream):
    """Yield the name of every ``Identifier`` directly in ``token_stream``.

    The names are collected eagerly, so the whole stream is consumed on the
    first iteration.
    """
    names = []
    for token in token_stream:
        if isinstance(token, Identifier):
            names.append(token.get_name())
    yield from names


def extract_table_identifiers(token_stream):
    """Yield table names from a stream of tokens.

    * ``IdentifierList``: the name of each identifier in it.
    * ``Identifier`` with several tokens (e.g. ``t AS a``): recurse into its
      tokens. The recursion stops at the first whitespace, so only the part
      before the alias is reported.
    * single-token ``Identifier``: its name.
    * keyword (table names that collide with keywords): its text.
    * anything else: its string form.

    NOTE(review): inside a multi-token Identifier every token before the first
    whitespace is yielded separately (e.g. a dotted name may come out as
    several entries) - confirm that is intended.
    """
    for item in token_stream:
        if isinstance(item, IdentifierList):
            for identifier in item.get_identifiers():
                yield identifier.get_name()
        elif isinstance(item, Identifier):
            if len(item.tokens) > 1:
                yield from extract_table_identifiers(item.tokens)
            else:
                yield item.get_name()
        elif item.ttype == Keyword:
            yield item.value
        elif item.is_whitespace:
            # End this (sub-)stream. `raise StopIteration` inside a generator is
            # turned into RuntimeError since Python 3.7 (PEP 479); `return` gives
            # the intended "stop here" behaviour.
            return
        else:
            yield str(item)


def extract_definitions(token_list):
    """Split the contents of ``token_list`` into comma-separated definitions.

    Scanning starts at ``token_list.token_next(1)`` and stops at a ``)``
    punctuation token. Members of an ``IdentifierList`` are examined one by
    one. Blank tokens are dropped, and each ``,`` closes the current
    definition. A trailing definition is kept only if it starts with an
    ``Identifier``.

    :return: list of definitions, each a list of tokens.
    """
    result = []
    current = []
    idx, tok = token_list.token_next(1)
    while tok and not tok.match(sqlparse.tokens.Punctuation, ')'):
        pieces = list(tok) if isinstance(tok, IdentifierList) else [tok]
        idx, tok = token_list.token_next(idx, skip_ws=False)
        for piece in pieces:
            if piece and piece.match(sqlparse.tokens.Punctuation, ','):
                result.append(current)
                current = []
            elif piece.value.strip():
                current.append(piece)
    if current and isinstance(current[0], sqlparse.sql.Identifier):
        result.append(current)
    return result
